# -*- coding: utf-8 -*-
"""
This code is used to generate the results for Dataset 3

@author: 
"""
import gzip
import json
from beep import structure
import os
import numpy as np
from numpy import gradient
import torch
from sklearn.model_selection import train_test_split
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors
from scipy.signal import savgol_filter
from scipy.optimize import differential_evolution
from sklearn.metrics import mean_squared_error
from joblib import Parallel, delayed
import pickle
from scipy.spatial import distance
from joblib import Parallel, delayed
from sklearn.linear_model import MultiTaskLasso
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy.signal import find_peaks
import joblib

import warnings
# Silence every warning for the whole session.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# pandas/scipy/sklearn calls below — consider narrowing the filter.
warnings.filterwarnings("ignore")

def set_random_seed(seed):
    """Make runs reproducible by seeding every random generator in use.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    CUDA), then switches cuDNN to deterministic, non-benchmarking mode.

    Parameters
    ----------
    seed : int
        Seed forwarded unchanged to every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # only deterministic cuDNN kernels
    cudnn.benchmark = False     # no kernel autotuning between runs

# Publication style for every figure: Times New Roman for text and mathtext,
# inward-pointing ticks, a thin black axes frame and a 12 pt base font.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Times New Roman',
    'mathtext.it': 'Times New Roman:italic',
    'mathtext.bf': 'Times New Roman:bold',
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'axes.edgecolor': 'black',
    'axes.linewidth': 0.8,
    'font.size': 12,
})

#%% Load half cell data and fitting function construction

# Half-cell OCP curves measured at C/5 (smoothed CSV exports).
OCPn_data = pd.read_csv('anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv')
OCPp_data = pd.read_csv('cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv')

# Negative electrode: voltage used in the order it is stored.
OCPn_SOC, OCPn_V = OCPn_data['SOC_linspace'].values, OCPn_data['Voltage'].values
# Positive electrode: voltage order is reversed against the SOC grid
# (flipped copy, detached from the DataFrame).
OCPp_SOC = OCPp_data['SOC_linspace'].values
OCPp_V = np.flip(OCPp_data['Voltage'].values).copy()

# Cubic interpolants that extrapolate outside the measured SOC range.
OCP_p, OCP_n = (
    interp1d(soc, volt, kind='cubic', fill_value='extrapolate', bounds_error=False)
    for soc, volt in ((OCPp_SOC, OCPp_V), (OCPn_SOC, OCPn_V))
)



# c/40
def openBioFile(path):
    """Read a BioLogic EC-Lab style text export into a DataFrame.

    The second line of the file is expected to end with ``: N``, where N is the
    total number of header lines (e.g. ``Nb header lines : 63`` — presumably the
    standard EC-Lab export; confirm for other instruments). The first ``N - 1``
    lines are skipped so that the last header line becomes the column names.
    Data are tab-separated and Latin-1 encoded.

    Parameters
    ----------
    path : str or os.PathLike
        Path to the text export.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If the file has fewer than two lines or the header count after the
        last ':' of line 2 is not an integer.
    """
    with open(path, 'rt', encoding="ISO-8859-1") as data:
        try:
            next(data)                # title line
            header_line = next(data)  # "... : N"
        except StopIteration:
            raise ValueError(f"{path}: too short to contain an EC-Lab header") from None
    # Parse everything after the last ':' instead of a fixed 3-character window,
    # which misread counts with no space after ':' or with 4+ digits.
    nb_header_lines = int(header_line.rsplit(":", 1)[-1].strip())
    return pd.read_csv(path, sep="\t", skiprows=nb_header_lines - 1, encoding="ISO-8859-1")


# Half-cell curves measured at C/40 (BioLogic text exports).
ne_data_Co40 = openBioFile(r"ZABC-A-097_Gr3_Co40_20220819_CD7.txt")
pe_data_Co40 = openBioFile(r"ZABC-A-097_NCA8_Co40_20221020_CD8.txt")

# Step numbers ("Ns" column) selecting the C/40 segment of each file.
cc_charge = 4
cc_discharge = 7
# NOTE(review): despite the "_charge" name, the positive-electrode selection
# uses the cc_discharge step (7); the negative electrode uses step 4.
pe_charge_Co40 = pe_data_Co40[(pe_data_Co40["Ns"] == cc_discharge)]
ne_charge_Co40 = ne_data_Co40[(ne_data_Co40["Ns"] == cc_charge)]
pe_capacity = pe_charge_Co40["Capacity/mA.h"].values
pe_voltage = pe_charge_Co40["Ecell/V"].values
# Voltage is reversed while capacity is not, so the voltage/capacity pairing
# is flipped. NOTE(review): presumably intentional to align directions — confirm.
pe_voltage = pe_voltage[::-1]
# Savitzky-Golay smoothing (window 251, linear); requires >= 251 samples.
pe_voltage_smoothed = savgol_filter(pe_voltage,window_length=251,polyorder=1)
ne_capacity = ne_charge_Co40["Capacity/mA.h"].values
ne_voltage = ne_charge_Co40["Ecell/V"].values
ne_voltage = ne_voltage[::-1]
ne_voltage_smoothed = savgol_filter(ne_voltage,window_length=251,polyorder=1)



# Resample capacity and smoothed voltage of each electrode onto 1000 points
# evenly spaced in sample index (not in capacity).
x_new = np.linspace(0, 1, 1000)
interp_ocv_c = interp1d(np.linspace(0, 1, len(pe_capacity)), pe_capacity, kind='linear')
interp_ocv_v = interp1d(np.linspace(0, 1, len(pe_voltage_smoothed)), pe_voltage_smoothed, kind='linear')
pe_capacity = interp_ocv_c(x_new)
pe_voltage_smoothed = interp_ocv_v(x_new)

interp_ocv_c = interp1d(np.linspace(0, 1, len(ne_capacity)), ne_capacity, kind='linear')
interp_ocv_v = interp1d(np.linspace(0, 1, len(ne_voltage_smoothed)), ne_voltage_smoothed, kind='linear')
ne_capacity = interp_ocv_c(x_new)
ne_voltage_smoothed = interp_ocv_v(x_new)

# SOC = capacity normalised by its last value (the last sample maps to 1).
OCPn_SOC_40 = ne_capacity/ne_capacity[-1]
OCPn_V_40 =ne_voltage_smoothed
OCPp_SOC_40 = pe_capacity/pe_capacity[-1]
# Positive-electrode voltage is flipped back again against its SOC grid.
OCPp_V_40 = pe_voltage_smoothed[::-1].copy()

# Cubic C/40 OCP interpolants that extrapolate outside the measured SOC range.
OCP_p_40 = interp1d(OCPp_SOC_40, OCPp_V_40, kind='cubic', fill_value='extrapolate', bounds_error=False)
OCP_n_40 = interp1d(OCPn_SOC_40, OCPn_V_40, kind='cubic', fill_value='extrapolate', bounds_error=False)


#%% only needed for the first run to extract data


def _clean_capacity_and_efc(capacity):
    """Remove capacity logging glitches and accumulate equivalent full cycles.

    Samples with ``|Q| > 1000`` are treated as glitches and replaced by the
    last non-glitch value before them. If there is none, for example a glitch
    on the very first row, 0.0 is used; indexing ``[i - 1]`` at row 0 would
    wrap to the last row. The capacity counter restarts whenever ``|Q|`` drops,
    so each drop adds the previous ``|Q|`` to a running offset.
    EFC is ``(|Q| + offset) / 2``; presumably one charge plus one discharge
    counts as one cycle.

    Parameters
    ----------
    capacity : array_like
        Raw capacity column, one value per logged sample.

    Returns
    -------
    cleaned : ndarray of float
        Capacity with glitches forward-filled (new array; input untouched).
    efc : ndarray of float
        Equivalent full cycles per sample, same length as ``capacity``.
    """
    raw = np.asarray(capacity, dtype=float)
    positions = np.arange(raw.size)
    is_glitch = np.abs(raw) > 1000
    # Position of the most recent non-glitch sample at or before each row (-1: none yet).
    last_good = np.maximum.accumulate(np.where(is_glitch, -1, positions))
    cleaned = np.where(last_good >= 0, raw[np.clip(last_good, 0, None)], 0.0)

    magnitude = np.abs(cleaned)
    increments = np.zeros_like(magnitude)
    # A drop in |Q| marks a counter reset: carry the previous |Q| into the offset.
    increments[1:] = np.where(magnitude[1:] < magnitude[:-1], magnitude[:-1], 0.0)
    efc = (magnitude + np.cumsum(increments)) / 2
    return cleaned, efc


def _longest_consecutive_run(indices):
    """Return the longest run of consecutive integers in sorted ``indices``.

    Ties keep the earliest run; an empty input yields an empty list.
    """
    best, current = [], []
    for idx in indices:
        if current and idx == current[-1] + 1:
            current.append(idx)
        else:
            if len(current) > len(best):
                best = current
            current = [idx]
    if len(current) > len(best):
        best = current
    return best


def _first_match(mask, start, stop):
    """Index of the first True in ``mask[start:stop]``, else ``stop - 1``.

    Mirrors a ``for ... break`` search whose variable keeps its last value when
    nothing matches. Returns None when the range is empty.
    """
    if stop <= start:
        return None
    hits = np.flatnonzero(mask[start:stop])
    return start + int(hits[0]) if hits.size else stop - 1


def _c40_discharge_indices(current, capacity, voltage, step, v_cutoff=2.801):
    """Return cycle-local indices of the C/40 discharge, or None if not found.

    The step containing the longest run of samples at exactly -0.025 C is taken
    as the C/40 step. Within that step the segment starts at the first
    discharging sample (current < 0) with positive capacity. It ends just
    before the first point where the voltage turns upward at or below
    ``v_cutoff``. That search stops at the first sample with non-negative
    current at or below the cutoff.

    When a search finds nothing, its last candidate is used. When a search
    range is empty, the result is None rather than indices left over from
    another cycle.
    """
    run = _longest_consecutive_run(np.flatnonzero(current == -0.025))
    if not run:
        return None
    seg = np.flatnonzero(step == step[run[0]])
    cur, cap, volt = current[seg], capacity[seg], voltage[seg]
    n = len(seg)

    start = _first_match((cur < 0) & (cap > 0), 0, n)
    end = _first_match((cur >= 0) & (volt <= v_cutoff), start, n - 1)
    if end is None:
        return None
    rising = np.zeros(n, dtype=bool)
    rising[:-1] = (volt[1:] > volt[:-1]) & (volt[:-1] <= v_cutoff)
    stop = _first_match(rising, start, end)
    if stop is None:
        return None
    return seg[start:stop]


def read_dynamic_data(file_name, num_points=1000):
    """Extract resampled C/40 discharge curves from one cell's cycling CSV.

    1. Clean capacity glitches and accumulate equivalent full cycles (EFC).
    2. Split the log by ``Cyc#``. Keep "RPT" cycles, defined as cycles with
       more than 20000 samples, lasting more than 200000 s, and followed by a
       cycle with more than 2000 samples.
    3. In each RPT cycle, locate the C/40 discharge with
       ``_c40_discharge_indices``.
    4. Resample voltage (divided by 4.2) and capacity onto ``num_points``
       points evenly spaced in sample index.

    Parameters
    ----------
    file_name : str
        CSV with columns ``Cyc#``, ``Step``, ``Test (Sec)``, ``Volts``,
        ``Normalized Capacity (nominal capacity unit)`` and
        ``Normalized Current (C-rate)``.
    num_points : int
        Resampling resolution.

    Returns
    -------
    inputs : ndarray, shape (n_cycles, num_points, 3)
        ``[V/4.2, Q, I/4.84]``. The current column repeats one value: sample 50
        of the segment, or the last sample if the segment is shorter.
        NOTE(review): the meaning of 4.84 is not documented here — confirm.
    outputs : ndarray, shape (n_cycles, num_points, 2)
        ``[V/4.2, Q]``.
    efcs : ndarray, shape (n_cycles,)
        Mean EFC over each extracted discharge.

    Cycles without a usable discharge (fewer than 2 points) are skipped, so the
    three arrays stay aligned. If no cycle is usable, the arrays are empty but
    keep their trailing dimensions.
    """
    cap_col = 'Normalized Capacity (nominal capacity unit)'
    dynamic_data = pd.read_csv(file_name, low_memory=False)

    cleaned_capacity, EFC_all = _clean_capacity_and_efc(dynamic_data[cap_col].to_numpy())
    # Write the cleaned capacity back explicitly so the per-cycle arrays use it.
    # Editing ``.values`` in place only reaches the DataFrame when that array is
    # a writable view, and fails under pandas Copy-on-Write.
    dynamic_data[cap_col] = cleaned_capacity

    # One record per cycle number. EFC is indexed by row label, which equals
    # row position for the default RangeIndex produced by read_csv.
    cycles = [
        {
            'V': group["Volts"].to_numpy(),
            'Q': group[cap_col].to_numpy(),
            'step': group["Step"].to_numpy(),
            'I': group["Normalized Current (C-rate)"].to_numpy(),
            't': group["Test (Sec)"].to_numpy(),
            'efc': EFC_all[group.index],
        }
        for _, group in dynamic_data.groupby("Cyc#")
    ]

    x_new = np.linspace(0, 1, num_points)
    inputs, outputs, efcs = [], [], []
    for cycle, next_cycle in zip(cycles[:-1], cycles[1:]):
        t = cycle['t']
        if not (len(t) > 20000 and (t[-1] - t[0]) > 200000 and len(next_cycle['t']) > 2000):
            continue

        idx = _c40_discharge_indices(cycle['I'], cycle['Q'], cycle['V'], cycle['step'])
        if idx is None or len(idx) < 2:
            continue  # nothing interpolable in this cycle

        discharge_capacity = cycle['Q'][idx]
        discharge_voltage = cycle['V'][idx] / 4.2
        segment_current = cycle['I'][idx]
        grid = np.linspace(0, 1, len(idx))
        voltage_interp = interp1d(grid, discharge_voltage, kind='linear')(x_new)
        capacity_interp = interp1d(grid, discharge_capacity, kind='linear')(x_new)
        c_rate = segment_current[min(50, len(segment_current) - 1)]

        inputs.append(np.stack(
            [voltage_interp, capacity_interp, c_rate * np.ones(num_points) / 4.84], axis=-1))
        outputs.append(np.stack([voltage_interp, capacity_interp], axis=-1))
        efcs.append(np.mean(cycle['efc'][idx]))

    if not inputs:
        return np.empty((0, num_points, 3)), np.empty((0, num_points, 2)), np.empty(0)
    return np.array(inputs), np.array(outputs), np.array(efcs)
        
    
# Folder with one cycling CSV per cell. The path is built portably instead of
# hard-coding Windows backslashes.
folder_loc = os.path.abspath(os.path.join('..', 'data_code', 'dynamicdata'))
# os.listdir order depends on the filesystem. Sort it so that cell indices,
# and the index-based train/test split further down, are reproducible.
# Only regular *.csv files are read.
file_list = sorted(
    f for f in os.listdir(folder_loc)
    if f.lower().endswith('.csv') and os.path.isfile(os.path.join(folder_loc, f))
)

# Per-cell extraction results; list position == cell index used later.
all_inputs, all_outputs = [], []
all_cells, all_efcs = [], []
for battery in file_list:
    print("read cell", battery)
    file_name = os.path.join(folder_loc, battery)
    inputs_cell, outputs_cell, efcs_cell = read_dynamic_data(file_name)
    all_inputs.append(inputs_cell)
    all_outputs.append(outputs_cell)
    all_cells.append((battery, "C/40_Cycle"))
    all_efcs.append(efcs_cell)


#%%

def objective_function(params, c_rate, measure_Q, measure_V):
    """Loss between a measured Q-V curve and the two-electrode OCV model.

    Electrode SOCs follow from the capacity throughput ``measure_Q``. The
    positive electrode starts at offset 0 and moves as ``+Q / Cp``. The
    negative electrode starts at ``NP_offset`` and moves as ``-Q / Cn``. The
    model voltage is ``OCP_p(SOC_p) - OCP_n(SOC_n)``.

    The loss is the mean distance from each measured (Q, V) point to its
    nearest model (Q, V) point, plus an L2 penalty of
    ``0.01 * (Cp**2 + Cn**2 + NP_offset**2)``.

    Parameters
    ----------
    params : sequence of 3 floats
        ``(Cp, Cn, NP_offset)``.
    c_rate : str
        ``'C/40_Cycle'`` or ``'C/5_Cycle'``; selects the half-cell OCP curves.
    measure_Q, measure_V : ndarray
        Measured capacity and full-cell voltage, same length.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        For any other ``c_rate``.
    """
    Cp, Cn, NP_offset = params
    pos_start = 0
    soc_pos = pos_start + measure_Q / Cp
    soc_neg = NP_offset - measure_Q / Cn

    if c_rate == 'C/40_Cycle':
        pos_curve, neg_curve = OCP_p_40, OCP_n_40
    elif c_rate == 'C/5_Cycle':
        pos_curve, neg_curve = OCP_p, OCP_n
    else:
        raise ValueError("Unsupported c_rate type.")
    model_v = pos_curve(soc_pos) - neg_curve(soc_neg)

    penalty = 0.01 * (Cp ** 2 + Cn ** 2 + NP_offset ** 2)

    measured_pts = np.column_stack((measure_Q, measure_V))
    model_pts = np.column_stack((measure_Q, model_v))
    nearest = distance.cdist(measured_pts, model_pts, "euclidean").min(axis=1)
    return nearest.mean() + penalty


def get_bounds_by_dof(dof):
    """Return ``(lower, upper, names)`` for a fit with ``dof`` free parameters.

    * 4 — Cp, Cn, x0, y0
    * 3 — Cp, Cn, NP_offset
    * 2 — NP_ratio, NP_offset

    Fresh lists are returned on every call, so callers may modify them.

    Raises
    ------
    ValueError
        If ``dof`` is not 2, 3 or 4.
    """
    table = {
        4: ([0.2, 0.2, 0, 0], [1.2, 1.2, 1.0, 1.0], ['Cp', 'Cn', 'x0', 'y0']),
        3: ([0.2, 0.2, 0], [1.2, 1.2, 1.0], ['Cp', 'Cn', 'NP_offset']),
        2: ([0.2, 0], [1.2, 1.0], ['NP_ratio', 'NP_offset']),
    }
    try:
        return table[dof]
    except (KeyError, TypeError):
        raise ValueError("DOF must be 2, 3 or 4") from None


def run_de(trial, Q, V, rate, popsize=20, maxiter=200):
    """Run one seeded differential-evolution fit of the 3-DOF OCV model.

    Parameters
    ----------
    trial : int
        Trial number, also used as the optimizer seed. Each trial is
        reproducible, and different trials start from different populations.
    Q, V : ndarray
        Measured capacity and full-cell voltage, forwarded to
        ``objective_function``.
    rate : str
        ``'C/40_Cycle'`` or ``'C/5_Cycle'``.
    popsize, maxiter : int
        Forwarded to ``scipy.optimize.differential_evolution`` (defaults
        unchanged).

    Returns
    -------
    (x, fun) : best ``[Cp, Cn, NP_offset]`` and its loss.
    """
    lower, upper, _ = get_bounds_by_dof(3)
    bounds = list(zip(lower, upper))
    # ``seed`` gives differential_evolution its own RNG, so NumPy's global RNG
    # is left untouched; reseeding it here would not change the result.
    result = differential_evolution(
        objective_function, bounds,
        args=(rate, Q, V),
        popsize=popsize, maxiter=maxiter,
        seed=trial,
    )
    return result.x, result.fun



def optimize_cycle(cycle_index, cell_inputs, cell_outputs, cell_efcs, cell_rate, cell_name,
                   *, num_trials=5, n_jobs=5):
    """Fit the 3-DOF electrode model to one extracted discharge curve.

    Runs ``num_trials`` independently seeded differential-evolution fits
    (``run_de``) in parallel and keeps the one with the lowest finite loss. It
    then evaluates the fitted full-cell OCV on the measured capacity grid.

    Parameters
    ----------
    cycle_index : int
        Row of ``cell_inputs`` / ``cell_outputs`` to fit.
    cell_inputs : ndarray
        Per-cycle inputs; index 1 of the last axis is capacity Q.
    cell_outputs : ndarray
        Per-cycle outputs; index 0 of the last axis is voltage / 4.2.
    cell_efcs : sequence
        Not used here; accepted so existing call sites keep working.
    cell_rate : str
        ``'C/5_Cycle'`` or ``'C/40_Cycle'``; selects the half-cell OCP curves.
    cell_name : str
        Only used in the error message.
    num_trials, n_jobs : int, keyword-only
        Number of DE restarts and joblib workers (defaults unchanged).

    Returns
    -------
    dict
        ``Cp_opt``, ``Cn_opt``, ``x0_opt`` (= NP_offset), ``y0_opt`` (always
        0), ``fitted_Voc`` (model voltage per point), plus ``cell_cap`` and
        ``Cq`` (both the last measured capacity).

    Raises
    ------
    RuntimeError
        If no trial returns a finite loss.
    ValueError
        If ``cell_rate`` is not supported.
    """
    measure_Q = cell_inputs[cycle_index, :, 1]
    measure_V = cell_outputs[cycle_index, :, 0] * 4.2  # undo the /4.2 normalisation

    results = Parallel(n_jobs=n_jobs)(
        delayed(run_de)(trial, measure_Q, measure_V, cell_rate) for trial in range(num_trials)
    )

    # Keep the lowest finite loss (ties -> earliest trial). If every trial
    # returned inf/NaN there is nothing to keep, so fail loudly.
    finite = [(params, fopt) for params, fopt in results if np.isfinite(fopt)]
    if not finite:
        raise RuntimeError(
            f"No finite DE result for cycle {cycle_index} of cell {cell_name!r}")
    best_params, _ = min(finite, key=lambda item: item[1])

    Cp_opt, Cn_opt, NP_offset = (float(p) for p in best_params)
    y0_opt = 0           # positive-electrode start offset is fixed in the 3-DOF model
    x0_opt = NP_offset   # negative-electrode start offset

    SOC_p_fit = y0_opt + measure_Q / Cp_opt
    SOC_n_fit = x0_opt - measure_Q / Cn_opt

    if cell_rate == 'C/5_Cycle':
        Up_fit, Un_fit = OCP_p(SOC_p_fit), OCP_n(SOC_n_fit)
    elif cell_rate == 'C/40_Cycle':
        Up_fit, Un_fit = OCP_p_40(SOC_p_fit), OCP_n_40(SOC_n_fit)
    else:
        raise ValueError("Unsupported cell_rate")

    return {
        'Cp_opt': Cp_opt,
        'Cn_opt': Cn_opt,
        'x0_opt': x0_opt,
        'y0_opt': y0_opt,
        'fitted_Voc': Up_fit - Un_fit,
        'cell_cap': cell_inputs[cycle_index, -1, 1],
        'Cq': measure_Q[-1],
    }


def optimize_cell(cell_index):
    """Fit every extracted cycle of one cell and group the results by field.

    Reads the module-level ``all_inputs``, ``all_outputs``, ``all_cells`` and
    ``all_efcs`` lists and fans the cycles out to ``optimize_cycle`` with
    joblib (64 workers).

    Parameters
    ----------
    cell_index : int
        Position of the cell in those lists.

    Returns
    -------
    tuple of 7 lists, each with one entry per cycle:
        Cp, Cn, x0, y0, fitted OCV curve, cell capacity, Cq.
    """
    entry = all_cells[cell_index]
    print(f"Processing cell {cell_index + 1}/{len(all_outputs)}: {entry}")
    name, rate = entry[0], entry[1]
    inputs = all_inputs[cell_index]
    outputs = all_outputs[cell_index]
    efcs = all_efcs[cell_index]

    jobs = (
        delayed(optimize_cycle)(cycle, inputs, outputs, efcs, rate, name)
        for cycle in range(len(inputs))
    )
    per_cycle = Parallel(n_jobs=64)(jobs)

    fields = ('Cp_opt', 'Cn_opt', 'x0_opt', 'y0_opt', 'fitted_Voc', 'cell_cap', 'Cq')
    return tuple([res[field] for res in per_cycle] for field in fields)

# 
# Containers for the fitting results: one entry per cell, each entry holding
# one value (or curve) per extracted cycle.
all_Cp_opt, all_Cn_opt, all_x0_opt, all_y0_opt = [], [], [], []
all_OCV_fit = []
all_cell_cap = []
all_Cq = []

# Fit all cells in parallel.
# NOTE(review): optimize_cell starts its own Parallel(n_jobs=64) and
# optimize_cycle another Parallel(n_jobs=5); check how joblib handles this
# nesting on the target machine before relying on the worker counts.
results = Parallel(n_jobs=56)(delayed(optimize_cell)(c_idx) for c_idx in range(len(all_outputs)))

# Each result is a 7-tuple of per-cycle lists for one cell; regroup by field.
for result in results:
    cell_Cp_opt, cell_Cn_opt, cell_x0_opt, cell_y0_opt, cell_OCV_fit, cell_cell_cap, cell_Cq = result
    all_Cp_opt.append(cell_Cp_opt)
    all_Cn_opt.append(cell_Cn_opt)
    all_x0_opt.append(cell_x0_opt)
    all_y0_opt.append(cell_y0_opt)
    all_OCV_fit.append(cell_OCV_fit)
    all_cell_cap.append(cell_cell_cap)
    all_Cq.append(cell_Cq)


# Bundle the fits with the extracted curves so later cells can run from the
# pickle without re-fitting.
data = {
    "all_Cp_opt": all_Cp_opt,
    "all_Cn_opt": all_Cn_opt,
    "all_x0_opt": all_x0_opt,
    "all_y0_opt": all_y0_opt,
    "all_Cq": all_Cq,
    "all_OCV_fit": all_OCV_fit,
    "all_cell_cap": all_cell_cap,
    "all_cell_ocv": all_outputs,
    "all_cell_vmea": all_inputs,
    "all_cells": all_cells,
    "all_efcs": all_efcs
}

with open("dynamic_extract_data_de.pkl", "wb") as f:
    pickle.dump(data, f)

print("data saved as: dynamic_extract_data_de.pkl")

#%%
norminal_c = 1  # multiplier applied to the fitted capacities in the plots below

# Reload the fitting results written by the extraction cell. This script
# produced the file itself, so unpickling it is trusted.
with open("dynamic_extract_data_de.pkl", "rb") as f:
    data = pickle.load(f)

# Re-apply the figure style so this cell also works when run on its own.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Times New Roman',
    'mathtext.it': 'Times New Roman:italic',
    'mathtext.bf': 'Times New Roman:bold',
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'axes.edgecolor': 'black',
    'axes.linewidth': 0.8,
    'font.size': 12,
})

# Unpack the saved per-cell lists.
(all_Cp_opt, all_Cn_opt, all_x0_opt, all_y0_opt, all_Cq, all_OCV_fit,
 all_cell_cap, all_cell_ocv, all_cell_vmea, all_cells, all_efcs) = (
    data["all_Cp_opt"], data["all_Cn_opt"], data["all_x0_opt"], data["all_y0_opt"],
    data["all_Cq"], data["all_OCV_fit"], data["all_cell_cap"], data["all_cell_ocv"],
    data["all_cell_vmea"], data["all_cells"], data["all_efcs"],
)


#%

colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']
# The all_* lists are the ones unpacked from `data` just above.
# Q_li = Cp*y0 + Cn*x0 per cycle (plotted as Q_li). y0 is fixed at 0 by the
# 3-DOF fit, so this currently reduces to Cn*x0.
all_Cli = [
    np.array(cp) * np.array(y0) + np.array(x0) * np.array(cn)
    for cp, y0, x0, cn in zip(all_Cp_opt, all_y0_opt, all_x0_opt, all_Cn_opt)
]

# Four panels against EFC: Cp, Cn, Q_li and Cq (one line per cell).
fig, axs = plt.subplots(1, 4, figsize=(26 / 2.54, 5.5 / 2.54), dpi=600, gridspec_kw={'wspace': 0.45})
# NOTE(review): this only affects the current (last) axes, axs[3].
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
for i in range(len(all_cell_cap)):
    axs[3].plot(all_efcs[i],np.array(all_Cq[i]) * norminal_c,'-D',markersize=2, alpha=0.5, color=colors[3])
    axs[0].plot(all_efcs[i],np.array(all_Cp_opt[i]) * norminal_c,'-o',markersize=2, alpha=0.5, color=colors[0])
    axs[1].plot(all_efcs[i],np.array(all_Cn_opt[i]) * norminal_c,'-^',markersize=2, alpha=0.5, color=colors[4])
    axs[2].plot(all_efcs[i],np.array(all_Cli[i]) * norminal_c,'-s',markersize=2, alpha=0.5, color=colors[1])

axs[0].set_xlabel('EFC [Cycles]')
axs[0].set_ylabel(r'${\mathrm{C_p}}$ [Ah]')
axs[1].set_xlabel('EFC [Cycles]')
axs[1].set_ylabel(r'${\mathrm{C_n}}$ [Ah]')
axs[2].set_xlabel('EFC [Cycles]')
axs[2].set_ylabel(r'${\mathrm{Q_{li}}}$ [Ah]')
axs[3].set_xlabel('EFC [Cycles]')
axs[3].set_ylabel(r'${\mathrm{C_{q}}}$ [Ah]')
plt.tight_layout()
plt.show()



#%%
# Hold out every third cell (positions 0, 3, 6, ...) for testing; the rest
# train the residual model. The range follows the number of loaded cells
# rather than a hard-coded count.
all_indices = np.arange(len(all_cells))
test_batteries_indices = all_indices[all_indices % 3 == 0]
train_batteries_indices = all_indices[all_indices % 3 != 0]


def _residual_dataset(cell_indices, message):
    """Build residual-model samples for the given cells (one per cycle).

    The fitted OCV and the measured voltage (times 4.2) of each cycle are
    smoothed with a Savitzky-Golay filter (window 31, order 1).

    * Features: row-major flattening of ``[V_fit, Cp, Cn, Q_li]`` at every
      resampled point, i.e. 4 * n_points values interleaved per point.
    * Target: ``V_meas - V_fit`` (n_points values).

    Returns two lists (features, targets) of 1-D arrays.
    """
    X_rows, Y_rows = [], []
    for i in cell_indices:
        print(message, all_cells[i][0])
        Cp = np.array(all_Cp_opt[i])
        Cn = np.array(all_Cn_opt[i])
        Cli = np.array(all_Cli[i])
        fit_OCV = np.array(all_OCV_fit[i])
        V_meas = all_cell_vmea[i][:, :, 0] * 4.2  # undo the /4.2 normalisation
        for j in range(len(np.array(all_Cq[i]))):
            voc_fit = savgol_filter(fit_OCV[j, :], window_length=31, polyorder=1)
            voc_real = savgol_filter(V_meas[j, :], window_length=31, polyorder=1)
            X_i = np.column_stack([
                voc_fit,
                np.full_like(voc_fit, Cp[j]),
                np.full_like(voc_fit, Cn[j]),
                np.full_like(voc_fit, Cli[j]),
            ])
            X_rows.append(X_i.flatten())
            Y_rows.append(voc_real - voc_fit)
    return X_rows, Y_rows


X_features, Y_targets = _residual_dataset(train_batteries_indices, 'load data for cell:')

# Start from a previously saved estimator and refit it on the training cells.
# NOTE(review): presumably a multi-output sklearn regressor (e.g. MultiTaskLasso) — confirm.
final_model = joblib.load('saved_fittings/'+'final_model_C40_train1.pkl')
model_name = 'dynamic_retrain_model_residual_C40_smooth_1.pkl'
X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)
print(X_all.shape)

# Standardisation statistics come from the training cells only and are reused
# unchanged for the test cells.
scaler_X = StandardScaler().fit(X_all)
scaler_Y = StandardScaler().fit(Y_all)
X_train_std = scaler_X.transform(X_all)
Y_train_std = scaler_Y.transform(Y_all)

final_model.fit(X_train_std, Y_train_std)
joblib.dump(final_model, 'saved_fittings/'+model_name)
# Reload so the following cells use exactly the artifact that was saved.
final_model = joblib.load('saved_fittings/'+model_name)

#%%
# Build the residual-model features/targets for the held-out test cells (same
# recipe as for the training cells) and score the retrained model.
X_features = []
Y_targets = []
for i_idx, i in enumerate(test_batteries_indices):  #227, 38, 26, 100, 235 [7,216]
    efcs = all_efcs[i]
    print('results for cell:',all_cells[i][0])
    # Per-cycle fitted parameters. The "_5" names are reused from a C/5
    # version; these cells were loaded as 'C/40_Cycle'.
    Cq_5 = np.array(all_Cq[i]) 
    Cp_5 = np.array(all_Cp_opt[i]) 
    Cn_5 = np.array(all_Cn_opt[i])
    Cli_5 = np.array(all_Cli[i])
    cell_ocv = all_cell_ocv[i]
    cell_Vreal = cell_ocv[:,:,0]*4.2  # undo the /4.2 voltage normalisation
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]
    
    fit_OCV = np.array(all_OCV_fit[i])
    
    # The measured discharge voltage serves as the "real" OCV target.
    real_OCV = cell_Vmea
    
    for j in range(len(Cq_5)):
        # One sample per extracted cycle.
        voc_fit = fit_OCV[j,:]
        voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]
        
        # Light Savitzky-Golay smoothing (window 31, linear) of both curves.
        voc_fit = savgol_filter(voc_fit, window_length=31, polyorder=1)
        voc_real = savgol_filter(voc_real, window_length=31, polyorder=1)
        
        cp = Cp_5[j]
        cn = Cn_5[j]
        cq = Cq_5[j]
        cli = Cli_5[j]  # Q_li = Cp*y0 + Cn*x0
        # Per-point features: fitted OCV plus the cycle's scalar parameters
        # repeated along the curve.
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp),
            np.full_like(voc_fit, cn),
            np.full_like(voc_fit, cli),
        ])
        
        # Row-major flatten interleaves [V_fit, Cp, Cn, Q_li] per point.
        X_features.append(X_i.flatten())  # shape: (4 * n_points,), i.e. (4000,) for 1000 points
        Y_targets.append((voc_real - voc_fit))  # residual, shape: (n_points,)
        
        

X_all = np.stack(X_features)  # shape: (n_samples, 4 * n_points)
Y_all = np.stack(Y_targets)   # shape: (n_samples, n_points)

# Predict in standardised space with the training-cell scalers, then map back.
Y_test_pred_std = final_model.predict(scaler_X.transform(X_all))
Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

test_rmse = np.sqrt(mean_squared_error(Y_all, Y_test_pred))
print(f"Test RMSE: {test_rmse:.4f}")

# Parity plot of true vs. predicted residuals (mV) over all test cycles,
# coloured by absolute error, with an inset histogram of the signed error.
# NOTE(review): `cmap` is built but the scatter uses 'coolwarm'.
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
true_mV = Y_all.reshape(-1) * 1000
pred_mV = Y_test_pred.reshape(-1) * 1000
error_mV = true_mV - pred_mV
orig_color_values = np.abs(error_mV)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)
plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
scatter1 = plt.scatter(true_mV, pred_mV,
                       c=color_values, alpha=0.8, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# y = x reference line across the range of true values (vectorised min/max
# instead of the builtin ones, which loop over every element in Python).
lo, hi = true_mV.min(), true_mV.max()
plt.plot([lo, hi], [lo, hi], '--', color='grey', linewidth=1)
plt.xlabel('Real values [mV]')
plt.ylabel('Predictions [mV]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# Colour values were normalised to [0, 1]; label the bar in absolute-error units.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

ax_hist = inset_axes(
    plt.gca(),
    width="40%", height="30%",
    bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height (figure coords)
    bbox_transform=plt.gcf().transFigure,
    loc='lower left'
)
ax_hist.hist(error_mV, bins=20, color='gray', edgecolor='black', linewidth=0.5)

ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)
ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
ax_hist.tick_params(axis='both', which='major', labelsize=6)
ax_hist.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
ax_hist.yaxis.offsetText.set_fontsize(6)
plt.show()

# Overlaid error histograms over all test points:
# 'Fitted' = residual left by the physics fit alone (Y_all = V_meas - V_fit);
# 'Compensated' = residual after adding the model's predicted correction.
fig2, ax2 = plt.subplots(figsize=(6/2.54, 6.5/2.54), dpi=600)
plt.ion()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# NOTE(review): the first tick_params call is fully overridden by the second.
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
ax2.hist(
    (Y_all[:, :].reshape(-1) * 1000),
    bins=20, color=colors[0], edgecolor='black',linewidth=0.5,label='Fitted'
)

ax2.hist(
    (Y_all[:, :].reshape(-1) * 1000 - Y_test_pred[:, :].reshape(-1) * 1000),
    bins=12, color=colors[1], edgecolor='black',linewidth=0.5,alpha=0.9,label='Compensated'
)
ax2.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
ax2.set_xlabel('Error [mV]')
ax2.set_ylabel('Count')
# Compact legend in the upper-left corner, slightly outside the axes edge.
legend = ax2.legend(
    handlelength=1.0,
    handletextpad=0.2,
    labelspacing=0.3,
    frameon=False,
    ncol=1,
    bbox_transform=ax2.transAxes,
    loc='upper left',
    bbox_to_anchor=(-0.05, 1)
)
plt.show()

#%%

# Detailed look at one held-out cell: parity plot of the residual model, then
# the cycle-by-cycle -dV/dQ curves (real / fitted / compensated) and their
# first peak in 0.15 < Q < 0.4.
# NOTE(review): index 20 requires at least 21 test cells.
for i_idx, i in enumerate([test_batteries_indices[20]]):  #20, 30, 26
    efcs = all_efcs[i]
    print('results for cell:',all_cells[i][0])
    # Per-cycle fitted quantities. The "_5" names are reused; cells were loaded
    # as 'C/40_Cycle'.
    cell_efc = np.array(all_efcs[i]) 
    Cq_5 = np.array(all_Cq[i]) 
    Cp_5 = np.array(all_Cp_opt[i]) 
    Cn_5 = np.array(all_Cn_opt[i])
    Cli_5 = np.array(all_Cli[i])
    cell_ocv = all_cell_ocv[i]
    cell_Vreal = cell_ocv[:,:,0]*4.2  # undo the /4.2 voltage normalisation
    cell_Qreal = cell_ocv[:,:,1]
    cell_vmea = all_cell_vmea[i]
    cell_Vmea = cell_vmea[:,:,0]*4.2
    cell_Qmea = cell_vmea[:,:,1]
    
    fit_OCV = np.array(all_OCV_fit[i])
    
    real_OCV = cell_Vmea
    # Features/targets for this cell only (same recipe as the test set).
    X_features = []
    Y_targets = []
    for j in range(len(Cq_5)):
        voc_fit = fit_OCV[j,:]
        voc_real = real_OCV[j,:]
        q_meas = cell_Qmea[j,:]
        v_meas = cell_Vmea[j,:]
        
        voc_fit = savgol_filter(voc_fit, window_length=31, polyorder=1)
        voc_real = savgol_filter(voc_real, window_length=31, polyorder=1)
        
        cp = Cp_5[j]
        cn = Cn_5[j]
        cq = Cq_5[j]
        cli = Cli_5[j]  # Q_li = Cp*y0 + Cn*x0
        X_i = np.column_stack([
            voc_fit,
            # q_meas,
            # v_meas,
            np.full_like(voc_fit, cp),
            np.full_like(voc_fit, cn),
            np.full_like(voc_fit, cli),
        ])
        
        X_features.append(X_i.flatten())  # shape: (4 * n_points,)
        Y_targets.append((voc_real - voc_fit))  # residual, shape: (n_points,)
        
        

    X_all = np.stack(X_features)  # shape: (n_cycles, 4 * n_points)
    Y_all = np.stack(Y_targets)   # shape: (n_cycles, n_points)
    
    Y_test_pred_std = final_model.predict(scaler_X.transform(X_all))
    Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)
    
    test_rmse = np.sqrt(mean_squared_error(Y_all, Y_test_pred))
    print(f"Test RMSE: {test_rmse:.4f}")
    
    # Parity plot (mV) coloured by absolute error, with an inset error histogram.
    # NOTE(review): `cmap` is built but the scatter uses 'coolwarm'.
    cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
    orig_color_values = np.abs(Y_all[:,:].reshape(-1)*1000-Y_test_pred[:,:].reshape(-1)*1000)
    norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
    color_values = norm(orig_color_values)
    plt.figure(figsize=(5.5 / 2.54,6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    scatter1 = plt.scatter(Y_all[:,:].reshape(-1)*1000, Y_test_pred[:,:].reshape(-1)*1000, 
                          c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
    # y = x reference drawn through the true values themselves.
    plt.plot(Y_all[:,:].reshape(-1)*1000,Y_all[:,:].reshape(-1)*1000,'--',color='grey',linewidth=1)
    plt.xlabel('Real values [mV]')
    plt.ylabel('Predictions [mV]')
    cbar = plt.colorbar()
    cbar.set_label('Absolute error')
    plt.tick_params(bottom=False, left=False)
    # Colour values were normalised to [0, 1]; label the bar in error units.
    ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
    tick_labels = ["{:.2f}".format(value) for value in ticks]
    cbar.set_ticks(norm(ticks))
    cbar.set_ticklabels(tick_labels)

    ax_hist = inset_axes(
        plt.gca(),
        width="40%", height="30%",
        bbox_to_anchor=(0.44, 0.19, 0.62, 0.6),  # x0, y0, width, height
        bbox_transform=plt.gcf().transFigure,
        loc='lower left'
    )
    ax_hist.hist(
        (Y_all[:, :].reshape(-1) * 1000 - Y_test_pred[:, :].reshape(-1) * 1000),
        bins=20, color='gray', edgecolor='black',linewidth=0.5
    )

    ax_hist.set_xlabel('Error [mV]', fontsize=6, labelpad=1)
    ax_hist.set_ylabel('Count', fontsize=6, labelpad=1)
    ax_hist.tick_params(axis='both', which='major', labelsize=6)
    plt.show()
    
    
    # Per-cycle height and position of the first -dV/dQ peak (NaN if none).
    peak_vals_orig = []
    peak_vals_fit = []
    peak_vals_comp = []
    
    peak_pos_orig = []
    peak_pos_fit = []
    peak_pos_comp = []
    plt.figure(figsize=(10 / 2.54,6 / 2.54), dpi=600)
    plt.ion()
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    plt.tick_params(top='on', right='on', which='both')
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    for j in range(0, len(Y_test_pred), 1): #([0,len(Y_test_pred)-1]): #range(len(Y_test_pred)): #len(Y_test_pred)
        Q = cell_Qmea[j,:]
        # Stop at the first cycle whose final capacity is below 0.8; the peak
        # lists therefore cover only the cycles before it.
        if Q[-1]<0.8 :
            break
        voc_real = savgol_filter(cell_Vmea[j,:], window_length=31, polyorder=1)
        voc_fit = savgol_filter(fit_OCV[j,:], window_length=31, polyorder=1)
        dv_dq_orig = gradient( voc_real, Q )
        dv_dq_fit = gradient( voc_fit, Q )
        # dv_dq_mea = gradient(measure_V, measure_Q)
        print(cell_Qmea[j,-1])
        # "Compensated" curve = fitted OCV plus the predicted residual.
        dv_dq_compensate = gradient( voc_fit+Y_test_pred[j,:], Q )
        # Heavier smoothing (window 91) of the differentiated curves.
        dv_dq_orig = savgol_filter(dv_dq_orig, window_length=91, polyorder=1)
        dv_dq_fit = savgol_filter(dv_dq_fit, window_length=91, polyorder=1)
        dv_dq_compensate = savgol_filter(dv_dq_compensate, window_length=91, polyorder=1)
        # plt.plot(cell_Qmea[j,:],cell_Vmea[j,:],'r')
        # plt.plot(cell_Qmea[j,:],fit_OCV[j,:],'b--')
        # plt.plot(cell_Qmea[j,:],fit_OCV[j,:]+Y_test_pred[j,:],'c--')
        
        # Darken the colour for each successive cycle.
        color_scale = 1 / (0.04 * j + 1)
        plt.plot(Q, -dv_dq_orig, color=np.array(mcolors.to_rgb(colors[0])) * color_scale, label='Real' if j == 0 else None)
        plt.plot(Q, -dv_dq_fit, '--', color=np.array(mcolors.to_rgb(colors[5])) * color_scale, label='Fitted' if j == 0 else None)
        plt.plot(Q, -dv_dq_compensate, '--', color=np.array(mcolors.to_rgb(colors[1])) * color_scale, label='Compensated' if j == 0 else None)
        
        # Peak search window on the capacity axis.
        region_mask = (Q < 0.4) & (Q > 0.15)
        Q_focus = Q[region_mask]
        
        real_focus = -dv_dq_orig[region_mask]
        fit_focus = -dv_dq_fit[region_mask]
        comp_focus = -dv_dq_compensate[region_mask]
    
        # Peaks with prominence >= 0.05; only the first one is recorded.
        peaks_real, _ = find_peaks(real_focus, prominence=0.05)
        peaks_fit, _ = find_peaks(fit_focus, prominence=0.05)
        peaks_comp, _ = find_peaks(comp_focus, prominence=0.05)
    
        if len(peaks_real) > 0:
            peak_vals_orig.append(real_focus[peaks_real[0]])
            peak_pos_orig.append(Q_focus[peaks_real[0]])
        else:
            peak_vals_orig.append(np.nan)
            peak_pos_orig.append(np.nan)
    
        if len(peaks_fit) > 0:
            peak_vals_fit.append(fit_focus[peaks_fit[0]])
            peak_pos_fit.append(Q_focus[peaks_fit[0]])
        else:
            peak_vals_fit.append(np.nan)
            peak_pos_fit.append(np.nan)
    
        if len(peaks_comp) > 0:
            peak_vals_comp.append(comp_focus[peaks_comp[0]])
            peak_pos_comp.append(Q_focus[peaks_comp[0]])
        else:
            peak_vals_comp.append(np.nan)
            peak_pos_comp.append(np.nan)    
        
    plt.ylim([0,3.5])
    plt.ylabel('dV/dQ [V/Ah]')
    plt.xlabel('Q [Ah]')
    # axs.legend(loc='best',
    #           handletextpad=0.1, 
    #           labelspacing=0.05,
    #           bbox_to_anchor=(0.35, 0.55),
    #           frameon=False)
    plt.show()

# Plot how the first dV/dQ peak detected in each curve evolves with EFC:
# peak height (top panel) and peak position (bottom panel) for the real,
# fitted and compensated curves collected in the loop above.

# One entry is appended to each peak list per processed curve, so the list
# length is the number of x points. (Bumping the leftover loop index ``j``
# only gave the right count when the loop ran to its very last iteration.)
n_points = len(peak_vals_orig)
efc_axis = cell_efc[:n_points]

# Create both stacked panels up front. ``plt.subplots()`` followed by
# ``plt.subplot(211)`` leaves an empty full-figure axes behind on
# matplotlib >= 3.6, which no longer removes overlapping axes. With
# ``sharex=True`` only the bottom panel shows x tick labels.
fig2, ax2 = plt.subplots(2, 1, sharex=True, figsize=(5/2.54, 6.5/2.54), dpi=600)
ax_val, ax_pos = ax2

peak_panels = (
    (ax_val, (peak_vals_orig, peak_vals_fit, peak_vals_comp), 'Peak'),
    (ax_pos, (peak_pos_orig, peak_pos_fit, peak_pos_comp), 'Position'),
)
for panel_ax, (series_real, series_fit, series_comp), panel_label in peak_panels:
    panel_ax.plot(efc_axis, series_real, '-o', markersize=5, color=colors[0], label='Real Peak')
    panel_ax.plot(efc_axis, series_fit, '--s', markersize=5, color=colors[5], label='Fitted Peak')
    panel_ax.plot(efc_axis, series_comp, '--^', markersize=5, color=colors[1], label='Compensated Peak')
    panel_ax.set_ylabel(panel_label)
    panel_ax.grid(True)

    # The lists hold NaN where no peak was found. Pad the y-range by -10 % /
    # +10 % around the finite values only. All-NaN or empty data would make
    # set_ylim raise, so autoscaling is kept in that case.
    finite_vals = np.asarray(series_real + series_fit + series_comp, dtype=float)
    finite_vals = finite_vals[np.isfinite(finite_vals)]
    if finite_vals.size:
        panel_ax.set_ylim(finite_vals.min() * 0.9, finite_vals.max() * 1.1)

ax_pos.set_xlabel('EFC')
plt.subplots_adjust(hspace=0.25)
plt.show()


#%% Accuracy comparison

def _ocv_error_metrics(pred_ocv, ref_ocv):
    """Return (RMSE, MAE, MaxAE, R2) of ``pred_ocv`` against ``ref_ocv``.

    Both arrays are flattened. RMSE, MAE and MaxAE are computed on the values
    multiplied by 1000 (presumably V -> mV; confirm the input units). R2 is
    computed on the unscaled values.
    """
    pred = np.asarray(pred_ocv).reshape(-1)
    ref = np.asarray(ref_ocv).reshape(-1)
    pred_scaled = pred * 1000
    ref_scaled = ref * 1000
    rmse = np.sqrt(mean_squared_error(ref_scaled, pred_scaled))
    mae = mean_absolute_error(ref_scaled, pred_scaled)
    max_ae = np.max(np.abs(ref_scaled - pred_scaled))
    r2 = r2_score(ref, pred)
    return rmse, mae, max_ae, r2


# Evaluate each saved residual model on the test cells. The model predicts the
# residual (measured - fitted OCV) from the fitted OCV plus the per-cycle
# parameters Cp, Cn, Cli. Its corrected OCV is compared with the plain fitted
# OCV as a baseline.
#
# NOTE(review): for the '*smooth*' model both the fitted and the measured OCV
# are Savitzky-Golay smoothed before the residual is formed. Its reference
# ``orig_ocv`` is therefore the *smoothed* measurement, while the other model
# is scored against the raw one. Confirm this is the intended comparison.
# NOTE(review): both models use the module-level ``scaler_X`` / ``scaler_Y``.
# Confirm these match the scalers each saved model was trained with.
model_names = [
    'dynamic_retrain_model_residual_C40_1.pkl',
    'dynamic_retrain_model_residual_C40_smooth_1.pkl',
]
for model_name in model_names:
    print('Test results for:', model_name)
    use_smoothing = 'smooth' in model_name

    X_features = []   # flattened per-point feature matrix, one per test curve
    Y_targets = []    # residual target (measured - fitted OCV), one per curve
    fit_curves = []   # fitted OCV actually fed to the model, one per curve

    for i in test_batteries_indices:
        # C/5 fitted parameters for every cycle of this cell.
        Cq_5 = np.array(all_Cq[i])
        Cp_5 = np.array(all_Cp_opt[i])
        Cn_5 = np.array(all_Cn_opt[i])
        Cli_5 = np.array(all_Cli[i])

        # Channel 0 holds voltage, rescaled by 4.2. The scaling suggests it is
        # stored normalised by 4.2 V (assumption; confirm against the loader).
        cell_Vmea = all_cell_vmea[i][:, :, 0] * 4.2
        fit_OCV = np.array(all_OCV_fit[i])

        for j in range(len(Cq_5)):
            voc_fit = fit_OCV[j, :]
            voc_real = cell_Vmea[j, :]

            if use_smoothing:
                voc_fit = savgol_filter(voc_fit, window_length=31, polyorder=1)
                voc_real = savgol_filter(voc_real, window_length=31, polyorder=1)

            # Per-point features: fitted OCV plus this cycle's scalar
            # parameters broadcast along the curve.
            X_i = np.column_stack([
                voc_fit,
                np.full_like(voc_fit, Cp_5[j]),
                np.full_like(voc_fit, Cn_5[j]),
                np.full_like(voc_fit, Cli_5[j]),
            ])

            # Row-major flatten interleaves the features per point:
            # [v0, cp, cn, cli, v1, cp, cn, cli, ...].
            X_features.append(X_i.flatten())
            Y_targets.append(voc_real - voc_fit)
            fit_curves.append(voc_fit)

    X_all = np.stack(X_features)   # (n_samples, n_points * 4)
    Y_all = np.stack(Y_targets)    # (n_samples, n_points)

    final_model = joblib.load(os.path.join('saved_fittings', model_name))

    Y_test_pred_std = final_model.predict(scaler_X.transform(X_all))
    Y_test_pred = scaler_Y.inverse_transform(Y_test_pred_std)

    # Take the fitted OCV from the collected curves rather than slicing X_all
    # with a stride of 4, so adding or removing feature columns cannot
    # silently break it. The values are identical.
    fit_ocv = np.stack(fit_curves)
    orig_ocv = fit_ocv + Y_all
    pre_ocv = fit_ocv + Y_test_pred

    test_rmse, test_mae, test_maae, test_r2 = _ocv_error_metrics(pre_ocv, orig_ocv)
    fit_rmse, fit_mae, fit_maae, fit_r2 = _ocv_error_metrics(fit_ocv, orig_ocv)

    print(f"Test RMSE: {test_rmse:.4f}", f"Test MAE: {test_mae:.4f}", f"Test MaxAE: {test_maae:.4f}", f"Test R2: {test_r2:.4f}")
    print(f"Fit RMSE: {fit_rmse:.4f}", f"Fit MAE: {fit_mae:.4f}", f"Fit MaxAE: {fit_maae:.4f}", f"Fit R2: {fit_r2:.4f}")


